{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Regression Models in Selene\n",
    "\n",
    "Selene is a flexible framework, and can be used for tasks beyond simple classification.\n",
    "This tutorial demonstrates the simple process of training regression models with Selene.\n",
    "For this example, we will predict mean ribosomal load (MRL) from 50 base pair 5' UTR sequences using models and data from [*Human 5′ UTR design and variant effect prediction from a massively parallel translation assay*](https://doi.org/10.1101/310375) by Sample et al.\n",
    "This data was generated from a massively parallel reporter assay (MPRA), which you can read more about it in the preprint on [*bioRxiv*](https://doi.org/10.1101/310375).\n",
    "\n",
    "## Setup\n",
    "\n",
    "**Architecture:** The model is defined in [utr_model.py](https://github.com/FunctionLab/selene/blob/master/tutorials/regression_mpra_example/utr_model.py), and only superficially differs from the model in [the paper](https://doi.org/10.1101/310375).\n",
    "Since this is a real-valued regression problem, it is appropriate that the `criterion` method in `utr_model.py` uses the mean squared error.\n",
    "\n",
    "\n",
    "**Data:** The data from Sample et al is available on the [Gene Expression Omnibus (GEO)](https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE114002).\n",
    "However, we have included [the `download_data.py` script](https://github.com/FunctionLab/selene/blob/master/tutorials/regression_mpra_example/download_data.py), to download the data and preprocess it.\n",
    "It should produce three files, `train.mat`, `validate.mat`, and `test.mat`.\n",
    "They include the data for training, validation, and testing respectively.\n",
    "At present, regression models can only be trained with `*.mat` files and the [`MatFileSampler`](http://selene.flatironinstitute.org/samplers.file_samplers.html#selene_sdk.samplers.file_samplers.MatFileSampler).\n",
    "Further, a `MatFileSampler` must be specified for each of the `train.mat`, `validate.mat`, and `test.mat` files.\n",
    "These `MatFileSampler`s are then used for the `train`, `validate`, and `test` arguments of a [`MultiFileSampler`](http://selene.flatironinstitute.org/samplers.html#selene_sdk.samplers.MultiFileSampler).\n",
    "The specific syntax is demonstrated in the configuration file, [`regression_train.yml`](https://github.com/FunctionLab/selene/blob/master/tutorials/regression_mpra_example/regression_train.yml).\n",
    "\n",
    "**Configuration file:** The configuration file [`regression_train.yml`](https://github.com/FunctionLab/selene/blob/master/tutorials/regression_mpra_example/regression_train.yml) is slightly different than the configuration files in the classification tutorials.\n",
    "Specifically, `metrics` in `train_model` includes the coefficient of determination (`r2`), since the default metrics (`roc_auc` and `average_precision`) are not appropriate for regression.\n",
    "Further, `report_gt_feature_n_positives` in `train_model` has been set to zero to prevent spurious filtering based on target values.\n",
    "\n",
    "## Download the data\n",
    "\n",
    "To download the data, just run the [`download_data.py`](https://github.com/FunctionLab/selene/blob/master/tutorials/regression_mpra_example/download_data.py) script from the command line:\n",
    "```sh\n",
    "python download_data.py\n",
    "```\n",
    "\n",
    "## Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from selene_sdk.utils import load_path\n",
    "from selene_sdk.utils import parse_configs_and_run"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before running `load_path` on `regression_train.yml`, please edit the YAML file to include the absolute path of the model file.\n",
    "\n",
    "Currently, the model is set to train on GPU.\n",
    "If you do not have CUDA on your machine, please set `use_cuda` to `False` in the configuration file. Note that using the CPU instead of GPU will slow down training considerably."
   ]
  },
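  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you are unsure whether CUDA is usable on your machine, a quick check along the following lines may help you decide how to set `use_cuda`. This is a minimal sketch: the helper name `cuda_is_available` is hypothetical, and it assumes PyTorch is the backend installed alongside Selene.\n",
    "\n",
    "```python\n",
    "def cuda_is_available():\n",
    "    # Return True only if PyTorch can be imported and reports a usable CUDA device.\n",
    "    try:\n",
    "        import torch\n",
    "    except ImportError:\n",
    "        return False\n",
    "    return torch.cuda.is_available()\n",
    "\n",
    "\n",
    "print('CUDA available:', cuda_is_available())\n",
    "```\n",
    "\n",
    "If this prints `False`, set `use_cuda` to `False` in `regression_train.yml` before loading the configuration."
   ]
  },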
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "configs = load_path(\"./regression_train.yml\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Outputs and logs saved to ./2018-12-09-15-53-59\n",
      "2018-12-09 15:54:01,335 - Creating validation dataset.\n",
      "2018-12-09 15:54:01,361 - 0.02456068992614746 s to load 20096 validation examples (157 validation batches) to evaluate after each training step.\n",
      "2018-12-09 15:54:24,581 - [STEP 2031] average number of steps per second: 88.0\n",
      "2018-12-09 15:54:25,020 - validation r2: 0.8104067907778664\n",
      "2018-12-09 15:54:25,125 - training loss: 0.2401450276374817\n",
      "2018-12-09 15:54:25,126 - validation loss: 0.18883540832502826\n",
      "2018-12-09 15:54:47,288 - [STEP 4062] average number of steps per second: 91.9\n",
      "2018-12-09 15:54:47,729 - validation r2: 0.8564685296471333\n",
      "2018-12-09 15:54:47,822 - training loss: 0.193187415599823\n",
      "2018-12-09 15:54:47,823 - validation loss: 0.14294122951995036\n",
      "2018-12-09 15:55:09,855 - [STEP 6093] average number of steps per second: 92.5\n",
      "2018-12-09 15:55:10,290 - validation r2: 0.8653072068202623\n",
      "2018-12-09 15:55:10,376 - training loss: 0.2143666297197342\n",
      "2018-12-09 15:55:10,377 - validation loss: 0.13429565685000389\n",
      "2018-12-09 15:55:32,461 - Creating test dataset.\n",
      "2018-12-09 15:55:32,490 - 0.025864839553833008 s to load 20096 test examples (157 test batches) to evaluate after all training steps.\n",
      "2018-12-09 15:55:32,950 - test r2: 0.9016817574999407\n"
     ]
    }
   ],
   "source": [
    "parse_configs_and_run(configs, lr=0.001)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "test-selene-env",
   "language": "python",
   "name": "test-selene-env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
